package cn.bugstack.openai.session.defaults;

import cn.bugstack.openai.exception.OpenAiSdkException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.ProxySelector;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Proxy switching: a {@link ProxySelector} that takes the proxy for the next selection from a
 * per-thread slot.
 *
 * <p>Callers put the desired {@link Proxy} (for example one built by {@link #getProxy(String)})
 * into {@link #proxyThreadLocal} on the thread that will make the request. {@link #select(URI)}
 * reads that value once and clears the slot; when nothing was stored, a direct connection
 * ({@link Proxy#NO_PROXY}) is returned.
 *
 * <p>NOTE(review): this only works if the HTTP client calls {@code select} on the same thread
 * that set the thread-local — confirm for asynchronous call paths.
 *
 * @author chj
 * @date 2024/3/29
 **/
public class SwitchProxySelector extends ProxySelector {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(SwitchProxySelector.class);

    /**
     * Proxy to use for the next {@link #select(URI)} call on the current thread. {@code select}
     * removes the value after reading it, so each stored proxy applies to one selection only.
     */
    public static final ThreadLocal<Proxy> proxyThreadLocal = new ThreadLocal<>();

    /** Port assumed for SOCKS proxies whose URL has no explicit port. */
    private static final int DEFAULT_SOCKS_PORT = 1080;

    /** Port assumed for {@code http://} proxies whose URL has no explicit port. */
    private static final int DEFAULT_HTTP_PORT = 80;

    /** Port assumed for {@code https://} proxies whose URL has no explicit port. */
    private static final int DEFAULT_HTTPS_PORT = 443;

    /**
     * Returns the proxy stored for the current thread (or {@link Proxy#NO_PROXY} when none is
     * stored) and clears the per-thread slot.
     *
     * @param uri the URI that is about to be connected to
     * @return a single-element list containing the chosen proxy
     * @throws IllegalArgumentException if {@code uri} is null, as required by
     *                                  {@link ProxySelector#select(URI)}
     */
    @Override
    public List<Proxy> select(URI uri) {
        if (uri == null) {
            throw new IllegalArgumentException("uri must not be null");
        }
        Proxy proxy = proxyThreadLocal.get();
        // Consume the value first so it is cleared even if logging below fails.
        proxyThreadLocal.remove();
        if (proxy == null) {
            proxy = Proxy.NO_PROXY;
        }
        // address() is null for NO_PROXY; String.valueOf keeps the log value non-null.
        LOGGER.message("使用代理")
                .context("uri", uri.toString())
                .context("proxyType", proxy.type().name())
                .context("proxyAddress", String.valueOf(proxy.address()))
                .debug();
        return Collections.singletonList(proxy);
    }

    /**
     * No-op: connection failures are not recorded, so a failing proxy is not excluded from
     * later selections.
     */
    @Override
    public void connectFailed(URI uri, SocketAddress sa, IOException ioe) {
        // intentionally empty
    }

    /**
     * Factory method that builds a {@link Proxy} from a proxy URL string (the {@code proxy}
     * value in meta.proxy).
     *
     * <p>Schemes are matched case-insensitively: {@code http} and {@code https} give an
     * {@link Proxy.Type#HTTP} proxy; {@code socket}, {@code socks} and {@code socks5} give a
     * {@link Proxy.Type#SOCKS} proxy. When the URL has no port, 80 ({@code http}),
     * 443 ({@code https}) or 1080 (SOCKS) is used. Any {@code username:password@} part is
     * ignored, because {@link Proxy} carries no credentials.
     *
     * <p>This method never throws: a null or blank string yields {@link Proxy#NO_PROXY}, and a
     * malformed URL or unsupported scheme is logged as an error and also yields
     * {@link Proxy#NO_PROXY}.
     *
     * @param proxyString proxy URL, e.g. {@code http://host:port}
     * @return the parsed proxy, or {@link Proxy#NO_PROXY}
     */
    public static Proxy getProxy(String proxyString) {
        if (StringUtils.isBlank(proxyString)) {
            return Proxy.NO_PROXY;
        }
        try {
            URI uri = new URI(proxyString.strip());
            String scheme = uri.getScheme();
            String host = uri.getHost();
            if (StringUtils.isEmpty(scheme) || StringUtils.isEmpty(host)) {
                // Caught and logged below; callers receive NO_PROXY for invalid input.
                throw new OpenAiSdkException("代理地址有误，当前代理地址【" + proxyString + "】，参考代理URL格式：" +
                        "【http://username:password@host:port、https://username:password@host:port、http://host:port、https://host:port】");
            }
            Proxy.Type type = resolveProxyType(scheme);
            if (type == null) {
                LOGGER.message("代理设置错误")
                        .context("proxyString", proxyString)
                        .context("scheme", scheme)
                        .error();
                return Proxy.NO_PROXY;
            }
            int port = uri.getPort();
            if (port == -1) {
                port = defaultPort(scheme, type);
            }
            return new Proxy(type, new InetSocketAddress(host, port));
        } catch (Exception e) {
            LOGGER.message("代理设置错误")
                    .context("proxyString", proxyString)
                    .exception(e).error();
        }
        return Proxy.NO_PROXY;
    }

    /** Maps a proxy URL scheme to a {@link Proxy.Type}; returns null for unsupported schemes. */
    private static Proxy.Type resolveProxyType(String scheme) {
        return switch (scheme.toLowerCase(Locale.ROOT)) {
            case "http", "https" -> Proxy.Type.HTTP;
            case "socket", "socks", "socks5" -> Proxy.Type.SOCKS;
            default -> null;
        };
    }

    /** Chooses the port to use when the proxy URL does not specify one. */
    private static int defaultPort(String scheme, Proxy.Type type) {
        if (type == Proxy.Type.SOCKS) {
            return DEFAULT_SOCKS_PORT;
        }
        return "https".equalsIgnoreCase(scheme) ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT;
    }
}
